#数据访问层

from bson import ObjectId
import pymongo
from pymongo.collection import Collection
import utils
from entities import  Device, User, MessageData, Session


# MongoDB connection settings read by db_collection().
DB_HOST = "localhost"
DB_PORT = 27017
DB_NAME = "iot"
# Primary-key field name of stored MongoDB documents.
FIELD_ID = "_id"
# Id key used in the application-side dicts passed to insert_one();
# insert_one() moves its value to FIELD_ID before inserting.
OBJ_ID = "id"
# Process-wide cache of MongoClient instances, keyed by (host, port).
_CLIENTS: dict = {}


def db_collection(coll_name)->Collection:
    """Return the collection ``coll_name`` of database ``DB_NAME``.

    One ``MongoClient`` per (DB_HOST, DB_PORT) is created lazily and then
    reused. A MongoClient owns a connection pool and background monitoring
    threads, so the previous behaviour of building a new client on every call
    opened fresh connections for every query and never closed them
    explicitly. pymongo recommends one client per process.

    NOTE(review): MongoClient is not fork-safe; if this module is used
    before the process forks, each child should create its own client.

    Args:
        coll_name: name of the collection.

    Returns:
        The pymongo ``Collection`` object.
    """
    key = (DB_HOST, DB_PORT)  # read at call time so runtime overrides are honoured
    client = _CLIENTS.get(key)
    if client is None:
        client = pymongo.MongoClient(host=DB_HOST, port=DB_PORT)
        _CLIENTS[key] = client
    return client[DB_NAME][coll_name]


def insert_one(coll: Collection, data: dict)->str:
    """Insert ``data`` into ``coll`` and return the document id as a string.

    The application-side id key ``OBJ_ID`` ("id") is moved to MongoDB's
    ``FIELD_ID`` ("_id"):

    * missing or None -> a fresh ObjectId is generated;
    * a str            -> converted with ObjectId() (raises bson's InvalidId
                          if it is not a valid ObjectId string);
    * anything else    -> stored as-is.

    Note: ``data`` is modified in place. ``"_id"`` is set (overwriting any
    existing value) and ``"id"`` is removed. If the conversion raises,
    ``data`` is left untouched.
    """
    _id = data.get(OBJ_ID)
    if _id is None:
        _id = ObjectId()  # generate a new ObjectId
    elif isinstance(_id, str):
        _id = ObjectId(_id)  # convert the string to an ObjectId
    data[FIELD_ID] = _id
    data.pop(OBJ_ID, None)
    coll.insert_one(data)
    return str(_id)

def update_one(coll: Collection, filter, data: dict):
    """Apply ``{"$set": data}`` to a single document matching ``filter``.

    Returns the pymongo result of ``Collection.update_one``.
    """
    set_operation = {"$set": data}
    return coll.update_one(filter=filter, update=set_operation)
    
def find_one_by_id(coll: Collection, _id: str):
    """Return the raw document whose ``_id`` equals ``_id``, or None.

    ``_id`` may be an ObjectId or its string form. Values that are not valid
    ObjectIds (None, malformed strings, other types) return None, the same
    as "not found". Previously a malformed string raised bson's InvalidId,
    and None was silently turned into a brand-new random id by
    ``ObjectId(None)``.
    """
    if not ObjectId.is_valid(_id):
        return None
    return coll.find_one({FIELD_ID: ObjectId(_id)})

def update_one_by_id(coll: Collection, _id, data):
    """Apply ``{"$set": data}`` to the document whose ``_id`` matches.

    A string id is converted to ObjectId first; any other value (e.g. an
    ObjectId) is used unchanged. Returns None.
    """
    key = ObjectId(_id) if isinstance(_id, str) else _id
    coll.update_one(filter={FIELD_ID: key}, update={"$set": data})

# Process-wide memo for check_session_index(); "index" is set once the
# session TTL index is known to exist.
_SESSION_INDEX_FLAG: dict = {}


def check_session_index(flag=None):
    """Ensure the session collection has its auto-expiry (TTL) index.

    The index on ``Session.FIELD_LAST_ACTIVE`` has
    ``expireAfterSeconds=Session.TIME_OUT_SECONDS``, so MongoDB deletes
    session documents that have been inactive for longer than that. Per
    MongoDB TTL semantics, the field must hold a date value for this to
    apply.

    Args:
        flag: optional dict used as a memo. ``flag["index"]`` is set once the
            index is confirmed or created, so later calls skip the database
            round trip. Defaults to a module-level dict shared by all
            callers. The original used a mutable default argument for the
            same purpose and never set the flag after creating the index.
    """
    if flag is None:
        flag = _SESSION_INDEX_FLAG
    if flag.get("index"):
        return  # index already confirmed in this process
    coll = db_collection(Session.COLL_NAME)
    info = coll.index_information()
    if not info.get(Session.INDEX_SESSION):
        # Index missing: create the expiring index.
        coll.create_index([(Session.FIELD_LAST_ACTIVE, 1)],
                          name=Session.INDEX_SESSION,
                          expireAfterSeconds=Session.TIME_OUT_SECONDS)
    flag["index"] = 1
    
               
def create_session(obj_id: str, obj_type: int)->Session:
    """Create, persist and return a new session for (obj_id, obj_type).

    Only one login per object is allowed at a time. Every existing session
    with the same obj_id and obj_type is deleted first, so a new login
    invalidates ("kicks out") earlier ones.

    Args:
        obj_id: id of the logged-in object.
        obj_type: kind of object the session belongs to.

    Returns:
        The newly created Session.
    """
    check_session_index()
    coll = db_collection(Session.COLL_NAME)
    # delete_many instead of delete_one: if concurrent logins ever left more
    # than one session for this object, delete_one would remove only one of
    # them and the remaining stale token(s) would stay valid.
    coll.delete_many({Session.FIELD_OBJ_ID: obj_id, Session.FIELD_OBJ_TYPE: obj_type})
    session = Session()
    session.obj_id = obj_id
    session.obj_type = obj_type
    session.new_token()  # presumably generates session.token
    session.update()  # refresh the last-active timestamp
    insert_one(coll, session.to_json())
    return session


def get_session(token: str, obj_type: int)->Session:
    """Look up a session by token and obj_type, refreshing its activity time.

    Returns the Session with an updated last-active timestamp (also written
    back to the database), or None when no matching session exists.
    """
    coll = db_collection(Session.COLL_NAME)
    record = coll.find_one({Session.FIELD_TOKEN: token, Session.FIELD_OBJ_TYPE: obj_type})
    if not record:
        return None
    session = Session()
    session.from_json(record)
    session.update()  # bump the last-active time
    refreshed = {
        Session.FIELD_TOKEN: token,
        Session.FIELD_OBJ_ID: session.obj_id,
        Session.FIELD_OBJ_TYPE: obj_type,
        Session.FIELD_LAST_ACTIVE: session.last_active,
    }
    # upsert=True: if the TTL index deleted the document between find_one
    # and this update, it is recreated instead of being silently lost.
    coll.update_one({FIELD_ID: record.get(FIELD_ID)}, {"$set": refreshed}, upsert=True)
    return session

def get_session_by_obj_id(obj_id: str, obj_type: int)->Session:
    """Return the session owned by (obj_id, obj_type), or None.

    Unlike get_session(), this does not refresh the last-active time.
    """
    sessions = db_collection(Session.COLL_NAME)
    record = sessions.find_one({Session.FIELD_OBJ_ID: obj_id, Session.FIELD_OBJ_TYPE: obj_type})
    if not record:
        return None
    found = Session()
    found.from_json(record)
    return found

def set_session_data(token: str, obj_type: int, data: dict):
    """Store ``data`` in the ``'data'`` field of the matching session document.

    The session is matched by token and obj_type. Returns None.
    """
    sessions = db_collection(Session.COLL_NAME)
    query = {Session.FIELD_TOKEN: token, Session.FIELD_OBJ_TYPE: obj_type}
    update_one(sessions, query, {'data': data})

def get_session_data(token: str, obj_type: int):
    """Return the ``'data'`` payload of the matching session, or None.

    None is returned both when no session matches and when the session has
    no ``'data'`` field.
    """
    sessions = db_collection(Session.COLL_NAME)
    record = sessions.find_one({Session.FIELD_TOKEN: token, Session.FIELD_OBJ_TYPE: obj_type})
    return record.get('data') if record else None

def get_user_id(name: str, pwd: str)->str:
    """User login: return the user's id (as str) if name and password match.

    The id is used to open a session. Returns None when the credentials do
    not match any user, or when ``name`` is None.

    Args:
        name: user name.
        pwd: plain-text password; None is treated as the empty string.
    """
    if name is None:
        # Without this guard the query {name: None} would also match user
        # documents whose name field is null or missing.
        return None
    if pwd is None:
        pwd = ""
    # SECURITY NOTE(review): the stored value appears to be an unsalted MD5
    # digest (utils.md5_string), which is weak against offline cracking.
    # Moving to a salted KDF (hashlib.scrypt / PBKDF2) requires migrating
    # stored hashes, so it is not changed here.
    md5_pwd = utils.md5_string(pwd)
    coll = db_collection(User.COLL_NAME)
    rec = coll.find_one({User.FIELD_NAME: name, User.FIELD_PWD: md5_pwd})
    if rec:
        return str(rec.get(FIELD_ID))
    return None

def get_user_by_id(user_id)->User:
    """Load the User with the given id; return None if no such user exists."""
    record = find_one_by_id(db_collection(User.COLL_NAME), user_id)
    if not record:
        return None
    found = User()
    found.from_json(record)
    return found

def update_user(user: User):
    """Save ``user``: ``$set`` every field of ``user.to_json()`` on its document."""
    update_one_by_id(db_collection(User.COLL_NAME), user._id, user.to_json())

def get_device_id(key: str):
    """Device login: return the id (as str) of the device whose key is ``key``.

    The id is used to open a session. Returns None if ``key`` is None or no
    device has that key.
    """
    if key is None:
        return None
    coll = db_collection(Device.COLL_NAME)
    rec = coll.find_one({Device.FIELD_KEY: key})
    if rec:
        return str(rec.get(FIELD_ID))
    return None  # unknown device key

def get_last_message_data(device_id: str, action: str=None)->MessageData:
    """Return the newest message of a device, or None if there is none.

    Args:
        device_id: device whose messages are searched.
        action: if not None, only messages with this action are considered.

    Returns:
        The MessageData with the greatest ``MessageData.FIELD_TIME``, or None.
    """
    coll = db_collection(MessageData.COLL_NAME)
    query = {MessageData.FIELD_DEVICE_ID: device_id}  # avoid shadowing builtin filter()
    if action is not None:
        query[MessageData.FIELD_ACTION] = action
    newest = coll.find_one(filter=query, sort=[(MessageData.FIELD_TIME, pymongo.DESCENDING)])
    if not newest:
        return None
    msg = MessageData()
    msg.from_json(newest)
    return msg

def get_message_data_list(device_id, page_size=20, page_no=0, sender=None, action=None):
    """Return one page of a device's messages (newest first) and the total count.

    Args:
        device_id: device whose messages are listed.
        page_size: messages per page; values <= 0 fall back to 20.
        page_no: 1-based page number; values <= 0 are treated as page 1.
            NOTE(review): because of this, the default 0 and an explicit 1
            return the same page. Confirm callers use 1-based numbering.
        sender: if not None, only messages from this sender.
        action: if not None, only messages with this action.

    Returns:
        ``(count, msgs)``: ``count`` is the number of messages matching the
        filters across all pages; ``msgs`` is the list of MessageData on the
        requested page.
    """
    coll = db_collection(MessageData.COLL_NAME)
    query = {MessageData.FIELD_DEVICE_ID: device_id}  # avoid shadowing builtin filter()
    if sender is not None:
        query[MessageData.FIELD_SENDER] = sender
    if action is not None:
        query[MessageData.FIELD_ACTION] = action
    if page_size <= 0:
        page_size = 20
    if page_no <= 0:
        page_no = 1
    count = coll.count_documents(filter=query)
    skip = (page_no - 1) * page_size
    cursor = (coll.find(filter=query)
              .sort([(MessageData.FIELD_TIME, pymongo.DESCENDING)])
              .skip(skip)
              .limit(page_size))
    msgs = []
    for rec in cursor:
        msg = MessageData()
        msg.from_json(rec)
        msgs.append(msg)
    return count, msgs

def insert_message_data(msg: MessageData)->str:
    """Insert ``msg`` into the message collection and return its id as a string."""
    messages = db_collection(MessageData.COLL_NAME)
    document = msg.to_json()
    return insert_one(messages, document)

def get_devices_by_user(user_id)->list[Device]:
    """Return the Device objects referenced by the user's ``devices`` list.

    Entries for which get_device_by_id() returns nothing are skipped. The
    remaining devices keep the order of ``user.devices``. Returns an empty
    list when the user does not exist or ``user.devices`` is None.
    """
    user = get_user_by_id(user_id=user_id)
    if not user or user.devices is None:
        return []
    # NOTE(review): one query per device (N+1). A single
    # {"_id": {"$in": [...]}} query would be cheaper for users with many devices.
    devices = (get_device_by_id(device_id=device_id) for device_id in user.devices)
    return [device for device in devices if device]

def get_device_by_id(device_id)->Device:
    """Load the Device with the given id; return None if no such device exists."""
    devices = db_collection(coll_name=Device.COLL_NAME)
    record = find_one_by_id(coll=devices, _id=device_id)
    if record is None:
        return None
    found = Device()
    found.from_json(record)
    return found

def get_device_by_active_code(active_code: str)->Device:
    """Find the device whose activation code equals ``active_code``.

    Returns None when no device matches or when ``active_code`` is None.
    The None guard matters: MongoDB's ``{field: None}`` matches documents
    where the field is null *or missing*, so a None code would otherwise
    return an arbitrary device that has no activation code.
    """
    if active_code is None:
        return None
    coll = db_collection(coll_name=Device.COLL_NAME)
    data = coll.find_one(filter={Device.FIELD_ACTIVE_CODE: active_code})
    if data:
        device = Device()
        device.from_json(data)
        return device
    return None

def update_device(device: Device):
    """Save ``device``: ``$set`` every field of ``device.to_json()`` on its document."""
    update_one_by_id(db_collection(Device.COLL_NAME), device._id, device.to_json())
 
def test02(obj_type: int = 0):
    """Manual smoke test: create a session, then refresh it every 3 s (20 times).

    Prints the token and last-active time after each refresh. It needs a
    running MongoDB at DB_HOST:DB_PORT. The original called create_session()
    and get_session() without their required ``obj_type`` argument and
    always failed with TypeError.

    Args:
        obj_type: session object type to test with. NOTE(review): 0 is a
            placeholder; substitute the project's real user/device type value.
    """
    import time
    obj_id = str(ObjectId())
    session = create_session(obj_id, obj_type)
    print(session.token, session.last_active)
    for _ in range(20):
        time.sleep(3)
        session = get_session(session.token, obj_type)
        if session is None:
            # The session expired or was removed; there is nothing to refresh.
            print("session not found")
            break
        print(session.token, session.last_active)

       
if __name__ == "__main__":
    pass
    # Manual tests against a live MongoDB; uncomment to run.
    #test02()
    #test_get_device_data()  # NOTE(review): not defined in the visible part of this file

    
